{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np\n",
    "from keras import backend as K\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.models import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round each number in arr to `places` decimal places.\n",
    "\n",
    "    Returns a new list of floats; arr itself is not modified.\n",
    "    \"\"\"\n",
    "    scale = 10**places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        # round() yields an int here; dividing by scale turns it back into a float.\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Test cases keyed by name (e.g. 'pipeline_00'), kept in insertion order for the JSON export.\n",
    "DATA = OrderedDict([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 0: three chained Conv2D layers (3x3 kernel, 'valid' padding, ReLU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Test case: three stacked Conv2D layers (4 filters, 3x3 kernel, 'valid' padding, ReLU, bias) on an 8x8x2 input.\n",
    "random_seed = 1000\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "layers = [\n",
    "    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True),\n",
    "    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True),\n",
    "    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True)\n",
    "]\n",
    "\n",
    "# Chain the layers functionally: input -> layers[0] -> ... -> layers[-1].\n",
    "# Each 'valid' 3x3 conv shrinks H and W by 2: 8x8 -> 6x6 -> 4x4 -> 2x2 (saved expected shape is [2, 2, 4]).\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = layers[0](input_layer)\n",
    "for layer in layers[1:-1]:\n",
    "    x = layer(x)\n",
    "output_layer = layers[-1](x)\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# Input values are drawn uniformly from [-1, 1).\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# set weights to random (use seed for reproducibility)\n",
    "# Weight array i from model.get_weights() is reseeded with random_seed + i and filled from [-1, 1).\n",
    "# NOTE(review): i == 0 reuses random_seed, so the first kernel (3*3*2*4 = 72 values) repeats the\n",
    "# first 72 values of data_in -- visible in the saved output. Kept as-is so the fixture is unchanged.\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# Forward pass on a batch of one; result[0] drops the batch axis.\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "\n",
    "# Arrays are stored flattened (ravel, C order) and rounded, with their shapes alongside.\n",
    "DATA['pipeline_00'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write all collected test cases in DATA to the JSON fixture for the Keras.js tests.\n",
    "# NOTE(review): the path is relative to the kernel's working directory -- run from the notebook's folder.\n",
    "filename = '../../test/data/pipeline/00.json'\n",
    "output_dir = os.path.dirname(filename)\n",
    "# exist_ok=True avoids the check-then-create race; the guard skips makedirs('') for a bare filename.\n",
    "if output_dir:\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_00\": {\"input\": {\"data\": [0.307179, -0.769986, 0.900566, -0.035617, 0.744949, -0.575335, -0.918581, -0.205611, -0.533736, 0.683481, -0.585835, 0.484939, -0.215692, -0.635487, 0.487079, -0.860836, 0.770674, 0.905289, 0.862287, -0.169138, -0.942037, 0.964055, -0.320725, 0.413374, -0.276246, -0.929788, 0.710117, 0.314507, 0.531366, 0.108174, 0.770186, 0.808395, -0.979157, -0.850887, -0.510742, -0.73339, 0.39585, -0.20359, 0.766244, -0.637985, -0.135002, -0.963714, 0.382876, -0.060619, -0.743556, 0.782674, 0.836407, -0.853758, -0.909104, -0.122854, 0.203442, -0.379546, 0.363816, -0.581974, 0.039209, 0.131978, -0.117665, -0.724888, -0.572914, -0.733256, -0.355407, -0.532226, 0.054996, 0.131942, -0.123549, -0.356255, 0.119282, 0.730691, 0.694566, -0.784366, -0.367361, -0.181043, 0.374178, 0.404471, -0.107604, -0.159356, 0.605261, 0.077235, 0.847001, -0.876185, -0.264833, 0.940798, 0.398208, 0.781951, -0.492895, 0.451054, -0.593011, 0.075024, -0.526114, -0.127016, 0.596049, -0.383182, 0.243033, -0.120704, 0.826647, 0.317372, 0.307263, -0.283084, 0.045883, -0.909825, -0.811381, 0.843224, -0.853911, 0.858835, 0.43403, 0.244621, 0.509979, -0.71666, 0.587644, 0.450806, 0.082879, -0.034831, -0.047045, -0.107934, 0.63214, -0.451297, -0.876076, 0.920593, -0.083764, 0.515784, 0.362317, 0.083556, -0.734335, -0.493589, -0.112289, 0.1175, 0.627729, -0.364343], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.307179, -0.769986, 0.900566, -0.035617, 0.744949, -0.575335, -0.918581, -0.205611, -0.533736, 0.683481, -0.585835, 0.484939, -0.215692, -0.635487, 0.487079, -0.860836, 0.770674, 0.905289, 0.862287, -0.169138, -0.942037, 0.964055, -0.320725, 0.413374, -0.276246, -0.929788, 0.710117, 0.314507, 0.531366, 0.108174, 0.770186, 0.808395, -0.979157, -0.850887, -0.510742, -0.73339, 0.39585, -0.20359, 0.766244, -0.637985, -0.135002, -0.963714, 0.382876, -0.060619, -0.743556, 0.782674, 0.836407, -0.853758, -0.909104, -0.122854, 0.203442, -0.379546, 0.363816, 
-0.581974, 0.039209, 0.131978, -0.117665, -0.724888, -0.572914, -0.733256, -0.355407, -0.532226, 0.054996, 0.131942, -0.123549, -0.356255, 0.119282, 0.730691, 0.694566, -0.784366, -0.367361, -0.181043], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.387536, -0.469873, -0.60788, -0.138957], \"shape\": [4]}, {\"data\": [-0.742023, -0.077688, -0.167692, 0.205448, -0.633864, -0.164175, -0.731823, 0.313236, 0.613465, -0.723716, -0.299231, 0.229032, 0.102561, 0.384949, -0.90948, -0.294898, -0.916217, -0.699031, -0.323329, -0.673445, 0.521949, -0.306796, -0.476018, -0.628623, 0.808028, -0.585043, -0.307429, -0.234868, -0.897584, 0.741743, 0.320785, 0.709132, -0.978084, 0.601894, -0.228816, -0.069558, -0.522066, -0.399597, -0.916222, 0.161549, -0.211915, 0.823372, -0.6549, -0.30403, 0.677588, -0.431259, 0.219659, -0.091937, -0.101636, -0.595218, -0.815428, 0.502932, 0.775249, 0.624226, 0.622601, -0.091075, 0.763603, 0.472659, 0.621131, -0.504549, -0.270214, 0.492749, 0.643055, -0.290058, -0.752162, 0.758918, 0.011832, -0.183967, 0.768298, 0.764241, 0.906398, 0.872853, -0.292238, 0.16788, -0.447741, 0.679196, 0.566614, 0.867549, -0.011606, -0.252108, 0.165669, -0.509362, 0.620632, -0.32465, -0.071143, -0.823613, 0.331067, -0.016903, -0.76138, -0.491146, 0.106088, -0.641492, 0.234893, 0.658853, -0.475623, 0.269103, 0.935505, -0.577134, 0.985015, -0.405957, -0.325882, 0.849518, -0.589155, 0.378331, -0.753075, 0.711411, 0.04547, 0.398327, -0.665657, 0.531142, -0.410293, -0.526649, 0.860648, 0.32795, -0.197082, -0.095526, -0.391361, 0.785465, -0.267269, -0.020154, -0.95189, -0.580742, 0.788104, -0.092433, 0.320354, 0.070651, 0.045416, 0.99799, 0.583116, -0.708131, -0.104784, -0.838947, -0.598224, 0.209105, 0.824956, 0.10438, 0.692046, -0.091308, 0.884896, 0.730617, 0.244486, -0.415624, -0.397714, -0.647236], \"shape\": [3, 3, 4, 4]}, {\"data\": [0.195612, -0.128132, -0.96626, 0.193375], \"shape\": [4]}, {\"data\": [-0.922097, 0.712992, 0.493001, 0.727856, 0.119969, -0.839034, 
-0.536727, -0.515472, 0.231, 0.214218, -0.791636, -0.148304, 0.309846, 0.742779, -0.123022, 0.427583, -0.882276, 0.818571, 0.043634, 0.454859, -0.007311, -0.744895, -0.368229, 0.324805, -0.388758, -0.556215, -0.542859, 0.685655, 0.350785, -0.312753, 0.591401, 0.95999, 0.136369, -0.58844, -0.506667, -0.208736, 0.548969, 0.653173, 0.128943, 0.180094, -0.16098, 0.208798, 0.666245, 0.347307, -0.384733, -0.88354, -0.328468, -0.515324, 0.479247, -0.360647, 0.09069, -0.221424, 0.091284, 0.202631, 0.208087, 0.582248, -0.164064, -0.925036, -0.678806, -0.212846, 0.960861, 0.536089, -0.038634, -0.473456, -0.409408, 0.620315, -0.873085, -0.695405, -0.024465, 0.762843, -0.928228, 0.557106, -0.65499, -0.918356, 0.815491, 0.996431, 0.115769, -0.751652, 0.075229, 0.969983, -0.80409, -0.080661, -0.644088, 0.160702, -0.486518, -0.09818, -0.191651, -0.961566, -0.238209, 0.260427, 0.085307, -0.664437, 0.458517, -0.824692, 0.312768, -0.253698, 0.761718, 0.551215, 0.566009, -0.85706, 0.687904, -0.283819, 0.5816, 0.820087, -0.028474, 0.588153, -0.221145, 0.049173, 0.529328, -0.359074, -0.463161, 0.493967, -0.852793, -0.552675, -0.695748, -0.178157, 0.477995, 0.858725, 0.120384, -0.515209, 0.204484, -0.025025, -0.654961, 0.239585, -0.654691, -0.651696, -0.699951, -0.054626, -0.232999, 0.464974, 0.285499, -0.311165, 0.18009, -0.100505, 0.303943, 0.265535, -0.960747, -0.542418, 0.195178, -0.848394, 0.0774, 0.250615, -0.690541, -0.106589], \"shape\": [3, 3, 4, 4]}, {\"data\": [0.318429, -0.858397, -0.059042, 0.68597], \"shape\": [4]}], \"expected\": {\"data\": [5.009162, 0.0, 0.0, 0.0, 1.770272, 3.243442, 0.0, 3.319521, 0.0, 2.15876, 0.0, 0.0, 4.509293, 0.188208, 0.0, 0.0], \"shape\": [2, 2, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# Echo the serialized test cases for a quick visual check.\n",
    "serialized = json.dumps(DATA)\n",
    "print(serialized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
